#!/usr/bin/env python
from config import *

from array import array
from struct import pack, unpack
from collections import deque

import pyaudio
import os
import subprocess
import wave
import urllib2

################################################################################
# Stream Manager
################################################################################

class StreamManager:
    """Capture audio from an input device and write it out as WAV files.

    Recorded samples are kept both interleaved (index 0 of ``mainbuffer`` /
    ``prebuffer``) and split per channel (index ``1 + channel``), so
    ``write_files`` can produce one combined file plus one mono file per
    channel under ``<path>/wavs/``.
    """

    def __init__(self, path, filename, channels=1):
        # path:     directory under which the "wavs/" output directory lives
        # filename: base name (no extension) of the combined file; channel i
        #           is written as <filename>_ch<i>
        # channels: number of interleaved input channels to record
        self.num_channels = channels
        # endswith() instead of path[-1] so an empty path does not raise
        # IndexError; an empty path yields the relative directory "wavs/".
        if path and not path.endswith("/"):
            path += "/"
        self.path = path + "wavs/"

        self.filenames = [filename] + [filename + '_ch' + str(i)
                                       for i in range(self.num_channels)]

    # Starts the StreamManager: opens the audio device and resets all buffers.
    def start(self):
        with suppress_stdout_stderr():
            self.p = pyaudio.PyAudio()

        # Find the right audio device: any device whose name contains
        # DEVICE_NAME; if several match, the last one wins. If none match,
        # device_index stays None and PyAudio falls back to its default.
        device_index = None
        for device_num in range(self.p.get_device_count()):
            device_info = self.p.get_device_info_by_index(device_num)
            if DEVICE_NAME in device_info['name']:
                device_index = device_num
                print("Using device %s" % device_info['name'])

        self.stream = self.p.open(format=FORMAT, channels=self.num_channels, rate=RATE,
                                  input=True, output=True,
                                  input_device_index=device_index,
                                  frames_per_buffer=CHUNK_SIZE)

        # mainbuffer stores data that we actively record while VAD is active
        # prebuffer is outdated and only used for recording wav files locally
        # buffer is what is sent to the recognizer and consists of packed byte data
        # For mainbuffer/prebuffer, index 0 holds interleaved samples and
        # index 1 + i holds channel i only.
        self.mainbuffer = [array('h') for _ in range(self.num_channels + 1)]
        # Floor division keeps maxlen an int; true division would produce a
        # float, which deque rejects.
        per_channel_len = CHUNK_SIZE * FRAME_BUFFER_LEN // self.num_channels
        self.prebuffer = [deque(maxlen=CHUNK_SIZE * FRAME_BUFFER_LEN)]
        self.prebuffer.extend(deque(maxlen=per_channel_len)
                              for _ in range(self.num_channels))
        self.buffer = deque(maxlen=FRAME_BUFFER_LEN)

        self.current_frame = array('h')

    # Appends `data` (a sequence of ints) to self.buffer as packed
    # little-endian signed 16-bit bytes.
    def record_buffer(self, data):
        self.buffer.append(pack('<%dh' % len(data), *data))

    # Call this whenever we're *not* recording
    def record_prebuffer(self):
        self.save_combined(self.prebuffer)
        self.save_channels(self.prebuffer)

    # Call this whenever we *are* recording
    def record(self):
        self.save_combined(self.mainbuffer)
        self.save_channels(self.mainbuffer)

    # Adds the current (interleaved) frame to slot 0 of buffer_data
    def save_combined(self, buffer_data):
        buffer_data[0].extend(self.current_frame)

    # De-interleaves the current frame: sample i goes to slot
    # 1 + (i % num_channels) of buffer_data
    def save_channels(self, buffer_data):
        for i, sample in enumerate(self.current_frame):
            buffer_data[i % self.num_channels + 1].append(sample)

    # Shuts down the StreamManager and releases the audio device
    def stop(self):
        self.stream.close()
        self.p.terminate()

    # Reads CHUNK_SIZE frames into current_frame as interleaved samples
    def read(self):
        data = self.stream.read(CHUNK_SIZE)
        # little endian, signed short; // keeps the sample count an int
        self.current_frame = array('h', unpack('<%dh' % (len(data) // 2), data))

    # Separates the data for the specified channel from the frame
    def channel_frame(self, channel):
        n = self.num_channels
        return array('h', [sample for i, sample in enumerate(self.current_frame)
                           if i % n == channel])

    # Writes the combined WAV file and individual files for each channel
    def write_files(self):
        self.write_file(self.filenames[0])
        for i in range(self.num_channels):
            self.write_file(self.filenames[i + 1], i)
        print("wrote files")

    # Writes one file (combined when channel is None, otherwise that channel
    # only) from prebuffer + mainbuffer, then, if DOWNSAMPLE is set and RATE
    # differs from DOWNSAMPLE_RATE, resamples it in place with sox.
    def write_file(self, filename, channel=None):
        if channel is None:
            index, num_channels = 0, self.num_channels
        else:
            index, num_channels = channel + 1, 1

        data = array('h')
        data.extend(self.prebuffer[index])
        data.extend(self.mainbuffer[index])

        filepath = self.path + filename + ".wav"
        self.write_wav(data, num_channels, filepath)

        if DOWNSAMPLE and RATE != DOWNSAMPLE_RATE:
            temp_path = self.path + filename + "_temp.wav"
            status = subprocess.call(["sox", filepath, "-b", "16", temp_path,
                                      "rate", str(DOWNSAMPLE_RATE)])
            if status == 0:
                # Same directory, so a rename replaces the file in one step
                # (instead of spawning cp + rm).
                os.rename(temp_path, filepath)
                print("downsampled file")
            else:
                # Keep the full-rate file rather than copying a missing or
                # partial temp file over it.
                if os.path.exists(temp_path):
                    os.remove(temp_path)
                print("sox failed (exit %s); kept %s at original rate"
                      % (status, filepath))

    # Writes a WAV file after trimming END_TRIM_TIME off the end and
    # peak-normalizing what remains.
    def write_wav(self, data, channels, path):
        sample_width = self.p.get_sample_size(FORMAT)
        # Trim before normalizing so the peak is measured on the samples that
        # are actually written, not on the discarded tail.
        data = self.trim(data, channels, END_TRIM_TIME)
        data = self.normalize(data)

        frames = pack('<%dh' % len(data), *data)
        wf = wave.open(path, 'wb')
        try:
            wf.setnchannels(channels)
            wf.setsampwidth(sample_width)
            wf.setframerate(RATE)
            wf.writeframes(frames)
        finally:
            wf.close()

    # Trim `seconds` worth of whole frames off the end of interleaved `data`.
    # `rate` defaults to RATE. Returns `data` unchanged when less than one
    # frame would be trimmed, and an empty sequence when trimming everything.
    def trim(self, data, channels, seconds, rate=None):
        if rate is None:
            rate = RATE
        # Whole frames only, so the remaining samples stay channel-aligned.
        frames_to_trim = int(seconds * rate)
        if frames_to_trim <= 0:
            # Guards the old data[:-0] case, which returned an empty array.
            return data
        samples_to_trim = frames_to_trim * channels
        return data[:max(0, len(data) - samples_to_trim)]

    # Scale `data` so its largest absolute sample becomes `maximum`
    # (must fit a signed short). Empty or all-zero input is returned as an
    # unscaled array('h') copy instead of raising.
    def normalize(self, data, maximum=16384):
        if len(data) == 0:
            return array('h')
        peak = max(abs(num) for num in data)
        if peak == 0:
            return array('h', data)
        times = float(maximum) / peak
        return array('h', [int(num * times) for num in data])

    # Path of the combined file (channel None) or of one channel's file
    def get_filepath(self, channel=None):
        index = 0 if channel is None else channel + 1
        return self.path + self.filenames[index] + ".wav"

################################################################################
# Utility Functions
################################################################################

def delete_file(path):
    """Delete the file at *path*; a file that is already gone is not an error.

    Uses os.remove rather than spawning ``rm``: no subprocess, and a path
    beginning with '-' cannot be misread as a command-line option.
    Other OSErrors (e.g. permission denied) are raised.
    """
    import errno
    try:
        os.remove(path)
    except OSError as err:
        if err.errno != errno.ENOENT:
            raise

def record_audio(sm, seconds, filename, channel=None):
    """Record for `seconds` seconds with `sm`, then write `filename`.

    :param sm: a StreamManager (anything with start/read/record/stop/
        write_file methods).
    :param seconds: wall-clock recording duration.
    :param filename: base name passed to sm.write_file.
    :param channel: None for the combined file, else the channel to write.
    """
    import time  # the original relied on an undefined `start_time`/`time`
    sm.start()
    end_time = time.time() + seconds
    try:
        while time.time() < end_time:
            sm.read()
            sm.record()
    finally:
        # Always release the audio device, even if a read fails.
        sm.stop()
    sm.write_file(filename, channel)

def play_wav(wav, chunk=1024):
    """Play an open wave reader through a new PyAudio output stream.

    :param wav: object returned by ``wave.open(path, 'rb')``; played from its
        current position to the end. It is not closed here.
    :param chunk: number of frames read and written per iteration.
    """
    with suppress_stdout_stderr():
        p = pyaudio.PyAudio()
    try:
        stream = p.open(format=p.get_format_from_width(wav.getsampwidth()),
                        channels=wav.getnchannels(),
                        rate=wav.getframerate(),
                        output=True)
        try:
            data = wav.readframes(chunk)
            # readframes returns an empty string ('' on Python 2, b'' on
            # Python 3) at EOF; a truthiness test works for both.
            while data:
                stream.write(data)
                data = wav.readframes(chunk)
            # Let queued audio finish playing; closing an active stream
            # discards pending buffers and clips the end of the sound.
            stream.stop_stream()
        finally:
            stream.close()
    finally:
        p.terminate()

def synthesize(text):
    """Fetch speech audio for `text` from SYNTHESIZER_URL, save it, play it.

    The text's words are joined with single spaces and URL-encoded (spaces
    become '+'), so characters such as '&', '?' or '#' cannot corrupt the
    query string. The response body is written to "synth.wav" in the
    current directory and then played with play_wav.
    """
    try:
        from urllib import quote_plus  # Python 2
    except ImportError:
        from urllib.parse import quote_plus  # Python 3
    inp = quote_plus(" ".join(text.split()))

    response = urllib2.urlopen(SYNTHESIZER_URL + inp)
    try:
        audio = response.read()
    finally:
        response.close()
    with open("synth.wav", "wb") as f:
        f.write(audio)

    wav = wave.open("synth.wav", "rb")
    try:
        play_wav(wav)
    finally:
        wav.close()

    print("done synthesizing text")

# I got this from Stack Overflow because the pyaudio errors were annoying
class suppress_stdout_stderr(object):
    '''
    Context manager that silences stdout and stderr at the file-descriptor
    level. Because fds 1 and 2 themselves are redirected to os.devnull,
    output written directly to them by compiled extension code is discarded
    too, not only Python-level print output. In this file it wraps the
    creation of pyaudio.PyAudio().

    Exceptions raised inside the block are not suppressed: __exit__ restores
    the real descriptors and returns None, so the exception propagates and
    its traceback is printed normally.

    Descriptors are opened in __enter__ and all closed in __exit__, so an
    instance leaks nothing and can be entered more than once (not nested).
    '''
    def __init__(self):
        # Opened per use in __enter__.
        self.null_fds = None
        self.save_fds = None

    @staticmethod
    def _flush_std_streams():
        # Flush Python-level buffers so pending text goes to the descriptor
        # that is current *now*, before fds 1 and 2 are swapped.
        import sys
        for stream in (sys.stdout, sys.stderr):
            if stream is not None:
                stream.flush()

    def __enter__(self):
        self._flush_std_streams()
        # Open a pair of null files
        self.null_fds = [os.open(os.devnull, os.O_RDWR) for _ in range(2)]
        # Save the actual stdout (1) and stderr (2) file descriptors.
        self.save_fds = (os.dup(1), os.dup(2))
        # Point stdout and stderr at the null files.
        os.dup2(self.null_fds[0], 1)
        os.dup2(self.null_fds[1], 2)
        return self

    def __exit__(self, *_):
        # Anything buffered while suppressed is flushed to /dev/null.
        self._flush_std_streams()
        # Re-assign the real stdout/stderr back to (1) and (2)
        os.dup2(self.save_fds[0], 1)
        os.dup2(self.save_fds[1], 2)
        # Close the null files and the saved duplicates (the latter were
        # previously leaked on every use).
        for fd in tuple(self.null_fds) + tuple(self.save_fds):
            os.close(fd)
